{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "第08课：从自然语言处理角度看 HMM 和 CRF"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "HMM 模型是由一个“五元组”组成的集合：\n",
    "\n",
    "StatusSet：状态值集合，状态值集合为 (B, M, E, S)，其中 B 为词的首个字，M 为词中间的字，E 为词语中最后一个字，S 为单个字，B、M、E、S 每个状态代表的是该字在词语中的位置。\n",
    "\n",
    "举个例子，对“中国的人工智能发展进入高潮阶段”，分词可以标注为：“中B国E的S人B工E智B能E发B展E进B入E高B潮E阶B段E”，最后的分词结果为：['中国', '的', '人工', '智能', '发展', '进入', '高潮', '阶段']。\n",
    "\n",
    "ObservedSet：观察值集合，观察值集合就是所有语料的汉字，甚至包括标点符号所组成的集合。\n",
    "\n",
    "TransProbMatrix：转移概率矩阵，状态转移概率矩阵的含义就是从状态 X 转移到状态 Y 的概率，是一个4×4的矩阵，即 {B,E,M,S}×{B,E,M,S}。\n",
    "\n",
    "EmitProbMatrix：发射概率矩阵，发射概率矩阵的每个元素都是一个条件概率，代表 P(Observed[i]|Status[j]) 概率。\n",
    "\n",
    "InitStatus：初始状态分布，初始状态概率分布表示句子的第一个字属于 {B,E,M,S} 这四种状态的概率。\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "import json "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "STATES = {'B', 'M', 'E', 'S'}\n",
    "EPS = 0.0001\n",
    "#定义停顿标点\n",
    "seg_stop_words = {\" \",\"，\",\"。\",\"“\",\"”\",'“', \"？\", \"！\", \"：\", \"《\", \"》\", \"、\", \"；\", \"·\", \"‘ \", \"’\", \"──\", \",\", \".\", \"?\", \"!\", \"`\", \"~\", \"@\", \"#\", \"$\", \"%\", \"^\", \"&\", \"*\", \"(\", \")\", \"-\", \"_\", \"+\", \"=\", \"[\", \"]\", \"{\", \"}\", '\"', \"'\", \"<\", \">\", \"\\\\\", \"|\" \"\\r\", \"\\n\",\"\\t\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其中的数据结构定义：\n",
    "\n",
    "trans_mat：状态转移矩阵，trans_mat[state1][state2] 表示训练集中由 state1 转移到 state2 的次数。\n",
    "\n",
    "emit_mat：观测矩阵，emit_mat[state][char] 表示训练集中单字 char 被标注为 state 的次数。\n",
    "\n",
    "init_vec：初始状态分布向量，init_vec[state] 表示状态 state 在训练集中出现的次数。\n",
    "\n",
    "state_count：状态统计向量，state_count[state]表示状态 state 出现的次数。\n",
    "\n",
    "word_set：词集合，包含所有单词。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "class HMM_Model:\n",
    "    def __init__(self):\n",
    "            self.trans_mat = {}  \n",
    "            self.emit_mat = {} \n",
    "            self.init_vec = {}  \n",
    "            self.state_count = {} \n",
    "            self.states = {}\n",
    "            self.inited = False\n",
    "            \n",
    "     #初始化数据结构    \n",
    "    def setup(self):\n",
    "        for state in self.states:\n",
    "            # build trans_mat\n",
    "            self.trans_mat[state] = {}\n",
    "            for target in self.states:\n",
    "                self.trans_mat[state][target] = 0.0\n",
    "            self.emit_mat[state] = {}\n",
    "            self.init_vec[state] = 0\n",
    "            self.state_count[state] = 0\n",
    "        self.inited = True\n",
    "        \n",
    "        \n",
    "    #模型保存   \n",
    "    def save(self, filename=\"../data/08/hmm.json\", code='json'):\n",
    "        fw = open(filename, 'w', encoding='utf-8')\n",
    "        data = {\n",
    "            \"trans_mat\": self.trans_mat,\n",
    "            \"emit_mat\": self.emit_mat,\n",
    "            \"init_vec\": self.init_vec,\n",
    "            \"state_count\": self.state_count\n",
    "        }\n",
    "        if code == \"json\":\n",
    "            txt = json.dumps(data)\n",
    "            txt = txt.encode('utf-8').decode('unicode-escape')\n",
    "            fw.write(txt)\n",
    "        elif code == \"pickle\":\n",
    "            pickle.dump(data, fw)\n",
    "        fw.close()\n",
    "        \n",
    "        \n",
    "     #模型加载\n",
    "    def load(self, filename=\"../data/08/hmm.json\", code=\"json\"):\n",
    "        fr = open(filename, 'r', encoding='utf-8')\n",
    "        if code == \"json\":\n",
    "            txt = fr.read()\n",
    "            model = json.loads(txt)\n",
    "        elif code == \"pickle\":\n",
    "            model = pickle.load(fr)\n",
    "        self.trans_mat = model[\"trans_mat\"]\n",
    "        self.emit_mat = model[\"emit_mat\"]\n",
    "        self.init_vec = model[\"init_vec\"]\n",
    "        self.state_count = model[\"state_count\"]\n",
    "        self.inited = True\n",
    "        fr.close()\n",
    "        \n",
    "    def do_train(self, observes, states):\n",
    "        if not self.inited:\n",
    "            self.setup()\n",
    "\n",
    "        for i in range(len(states)):\n",
    "            if i == 0:\n",
    "                self.init_vec[states[0]] += 1\n",
    "                self.state_count[states[0]] += 1\n",
    "            else:\n",
    "                self.trans_mat[states[i - 1]][states[i]] += 1\n",
    "                self.state_count[states[i]] += 1\n",
    "                if observes[i] not in self.emit_mat[states[i]]:\n",
    "                    self.emit_mat[states[i]][observes[i]] = 1\n",
    "                else:\n",
    "                    self.emit_mat[states[i]][observes[i]] += 1\n",
    "                    \n",
    "                    \n",
    "    #频数转频率\n",
    "    def get_prob(self):\n",
    "        init_vec = {}\n",
    "        trans_mat = {}\n",
    "        emit_mat = {}\n",
    "        default = max(self.state_count.values())  \n",
    "\n",
    "        for key in self.init_vec:\n",
    "            if self.state_count[key] != 0:\n",
    "                init_vec[key] = float(self.init_vec[key]) / self.state_count[key]\n",
    "            else:\n",
    "                init_vec[key] = float(self.init_vec[key]) / default\n",
    "\n",
    "        for key1 in self.trans_mat:\n",
    "            trans_mat[key1] = {}\n",
    "            for key2 in self.trans_mat[key1]:\n",
    "                if self.state_count[key1] != 0:\n",
    "                    trans_mat[key1][key2] = float(self.trans_mat[key1][key2]) / self.state_count[key1]\n",
    "                else:\n",
    "                    trans_mat[key1][key2] = float(self.trans_mat[key1][key2]) / default\n",
    "\n",
    "        for key1 in self.emit_mat:\n",
    "            emit_mat[key1] = {}\n",
    "            for key2 in self.emit_mat[key1]:\n",
    "                if self.state_count[key1] != 0:\n",
    "                    emit_mat[key1][key2] = float(self.emit_mat[key1][key2]) / self.state_count[key1]\n",
    "                else:\n",
    "                    emit_mat[key1][key2] = float(self.emit_mat[key1][key2]) / default\n",
    "        return init_vec, trans_mat, emit_mat\n",
    "    \n",
    "#     预测采用 Viterbi 算法求得最优路径\n",
    " #模型预测\n",
    "    def do_predict(self, sequence):\n",
    "        tab = [{}]\n",
    "        path = {}\n",
    "        init_vec, trans_mat, emit_mat = self.get_prob()\n",
    "\n",
    "        # 初始化\n",
    "        for state in self.states:\n",
    "            tab[0][state] = init_vec[state] * emit_mat[state].get(sequence[0], EPS)\n",
    "            path[state] = [state]\n",
    "\n",
    "        # 创建动态搜索表\n",
    "        for t in range(1, len(sequence)):\n",
    "            tab.append({})\n",
    "            new_path = {}\n",
    "            for state1 in self.states:\n",
    "                items = []\n",
    "                for state2 in self.states:\n",
    "                    if tab[t - 1][state2] == 0:\n",
    "                        continue\n",
    "                    prob = tab[t - 1][state2] * trans_mat[state2].get(state1, EPS) * emit_mat[state1].get(sequence[t], EPS)\n",
    "                    items.append((prob, state2))\n",
    "                best = max(items)  \n",
    "                tab[t][state1] = best[0]\n",
    "                new_path[state1] = path[best[1]] + [state1]\n",
    "            path = new_path\n",
    "\n",
    "        # 搜索最有路径\n",
    "        prob, state = max([(tab[len(sequence) - 1][state], state) for state in self.states])\n",
    "        return path[state]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    " def get_tags(src):\n",
    "        tags = []\n",
    "        if len(src) == 1:\n",
    "            tags = ['S']\n",
    "        elif len(src) == 2:\n",
    "            tags = ['B', 'E']\n",
    "        else:\n",
    "            m_num = len(src) - 2\n",
    "            tags.append('B')\n",
    "            tags.extend(['M'] * m_num)\n",
    "            tags.append('E')\n",
    "        return tags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cut_sent(src, tags):\n",
    "        word_list = []\n",
    "        start = -1\n",
    "        started = False\n",
    "\n",
    "        if len(tags) != len(src):\n",
    "            return None\n",
    "\n",
    "        if tags[-1] not in {'S', 'E'}:\n",
    "            if tags[-2] in {'S', 'E'}:\n",
    "                tags[-1] = 'S'  \n",
    "            else:\n",
    "                tags[-1] = 'E'  \n",
    "\n",
    "        for i in range(len(tags)):\n",
    "            if tags[i] == 'S':\n",
    "                if started:\n",
    "                    started = False\n",
    "                    word_list.append(src[start:i])  \n",
    "                word_list.append(src[i])\n",
    "            elif tags[i] == 'B':\n",
    "                if started:\n",
    "                    word_list.append(src[start:i])  \n",
    "                start = i\n",
    "                started = True\n",
    "            elif tags[i] == 'E':\n",
    "                started = False\n",
    "                word = src[start:i+1]\n",
    "                word_list.append(word)\n",
    "            elif tags[i] == 'M':\n",
    "                continue\n",
    "        return word_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    " class HMMSoyoger(HMM_Model):\n",
    "        def __init__(self, *args, **kwargs):\n",
    "                super(HMMSoyoger, self).__init__(*args, **kwargs)\n",
    "                self.states = STATES\n",
    "                self.data = None\n",
    "                \n",
    "     #加载语料\n",
    "        def read_txt(self, filename):\n",
    "            self.data = open(filename, 'r', encoding=\"utf-8\")\n",
    "            \n",
    "        def train(self):\n",
    "            if not self.inited:\n",
    "                self.setup()\n",
    "\n",
    "            for line in self.data:\n",
    "                line = line.strip()\n",
    "                if not line:\n",
    "                    continue\n",
    "\n",
    "               #观测序列\n",
    "                observes = []\n",
    "                for i in range(len(line)):\n",
    "                    if line[i] == \" \":\n",
    "                        continue\n",
    "                    observes.append(line[i])\n",
    "\n",
    "                #状态序列\n",
    "                words = line.split(\" \")  \n",
    "\n",
    "                states = []\n",
    "                for word in words:\n",
    "                    if word in seg_stop_words:\n",
    "                        continue\n",
    "                    states.extend(get_tags(word))\n",
    "                #开始训练\n",
    "                if(len(observes) >= len(states)):\n",
    "                    self.do_train(observes, states)\n",
    "                else:\n",
    "                    pass\n",
    "                \n",
    "        def lcut(self, sentence):\n",
    "            try:\n",
    "                tags = self.do_predict(sentence)\n",
    "                return cut_sent(sentence, tags)\n",
    "            except:\n",
    "                return sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "soyoger = HMMSoyoger()\n",
    "soyoger.read_txt(\"../data/08/syj_trainCorpus_utf8.txt\")\n",
    "soyoger.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['中国', '的', '人工', '智能', '发展', '进入', '高潮', '阶段', '。']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "soyoger.lcut(\"中国的人工智能发展进入高潮阶段。\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import genius\n",
    "text = u\"\"\"中文自然语言处理是人工智能技术的一个重要分支。\"\"\"\n",
    "seg_list = genius.seg_text(\n",
    "    text,\n",
    "    use_combine=True,\n",
    "    use_pinyin_segment=True,\n",
    "    use_tagging=True,\n",
    "    use_break=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<genius.word.Word at 0x105645588>,\n",
       " <genius.word.Word at 0x1273c1a90>,\n",
       " <genius.word.Word at 0x1273c19b0>,\n",
       " <genius.word.Word at 0x1273c1b00>,\n",
       " <genius.word.Word at 0x1273c1ac8>,\n",
       " <genius.word.Word at 0x1273c1a58>,\n",
       " <genius.word.Word at 0x1273c1b38>,\n",
       " <genius.word.Word at 0x1273c1ba8>,\n",
       " <genius.word.Word at 0x1273c1be0>,\n",
       " <genius.word.Word at 0x1273c1b70>,\n",
       " <genius.word.Word at 0x105645b70>]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seg_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'中文'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "seg_list[0].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['中文', '自然语言', '处理', '是', '人工智能', '技术', '的', '一个', '重要', '分支', '。']"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[word.text  for word in seg_list]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其中，genius.seg_text 函数接受5个参数，其中 text 是必填参数：\n",
    "\n",
    "text 第一个参数为需要分词的字。\n",
    "use_break 代表对分词结构进行打断处理，默认值 True。\\n\n",
    "use_combine 代表是否使用字典进行词合并，默认值 False。\\n\n",
    "use_tagging 代表是否进行词性标注，默认值 True。\\n\n",
    "use_pinyin_segment 代表是否对拼音进行分词处理，默认值 True。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
